JuliaCon 2022
Patrick Altmeyer
CounterfactualExplanations.jl: From human to data-driven decision-making …
… where black boxes are a recipe for disaster.
“You cannot appeal to (algorithms). They do not listen. Nor do they bend.”
— Cathy O’Neil in Weapons of Math Destruction, 2016
Ground Truthing
Probabilistic Models
Counterfactual Reasoning
Let \(\mathcal{D}=\{(x,y)\}\) denote our true population of input-output pairs. Then we want to find a subsample of the true population
\[\mathcal{D}_n \subset \mathcal{D}\]
such that
\[\mathcal{D}_n \sim p(\mathcal{D})\]
There are lots of open questions and plenty of work to be done here, but not today.
Probabilistic Models
Let \(p(\mathcal{D}_n|\theta)\) denote the likelihood of observing our subsample \(\mathcal{D}_n\) under some model parameterized by \(\theta\). Then we typically want to maximize this likelihood with respect to the parameters (Murphy 2022):
\[\arg \max_{\theta} p(\mathcal{D}_n|\theta)\]
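To make this concrete, here is a minimal, self-contained sketch (not package code) of maximum-likelihood estimation for a logistic model, fitted to synthetic data by gradient ascent on the log-likelihood. All names, data and hyperparameters are illustrative assumptions.

```julia
# Illustrative sketch: MLE for logistic regression via gradient ascent.
# All data and hyperparameters are made up for this example.
using Random
Random.seed!(2022)

sigmoid(z) = 1 / (1 + exp(-z))

# Synthetic data generated from known parameters θ_true:
n = 1000
X = randn(n, 2)
θ_true = [2.0, -1.0]
y = Float64.(rand(n) .< sigmoid.(X * θ_true))

# Gradient ascent on the log-likelihood: ∇θ log p(Dₙ|θ) = Xᵀ(y - σ(Xθ))
θ = zeros(2)
for _ in 1:2000
    θ .+= 0.5 / n .* (X' * (y .- sigmoid.(X * θ)))
end
# θ should now be close to θ_true
```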
[…] deep neural networks are typically very underspecified by the available data, and […] parameters [therefore] correspond to a diverse variety of compelling explanations for the data. (Wilson 2020)
In this setting it is often crucial to treat models probabilistically!
Probabilistic models are covered only briefly today; more in my other talk.
Counterfactual Reasoning
We can now make predictions - great! But do we know how the predictions are actually made?
Let \(\hat\theta\) denote our MLE estimate (or MAP estimate in the probabilistic setting). Then we are interested in understanding how the predictions of our model change with respect to changes in the input:
\[\nabla_x p(y|x,\mathcal{D}_n)\]
Even though […] interpretability is of great importance and should be pursued, explanations can, in principle, be offered without opening the “black box”. (Wachter, Mittelstadt, and Russell 2017)
The objective originally proposed by Wachter, Mittelstadt, and Russell (2017) is as follows:
\[ \min_{x^\prime \in \mathcal{X}} h(x^\prime) \quad \text{s.t.} \quad M(x^\prime) = t \qquad(1)\]
where \(h\) relates to the complexity of the counterfactual and \(M\) denotes the classifier.
Typically this is approximated through regularization:
\[ x^\prime = \arg \min_{x^\prime} \ell(M(x^\prime),t) + \lambda h(x^\prime) \qquad(2)\]
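As a toy illustration of Equation 2 (not the package implementation), the following sketch runs gradient descent on the regularized objective for a hand-coded logistic classifier, using cross-entropy loss and \(h(x^\prime) = \lVert x^\prime - x \rVert_2^2\). The factual, weights and hyperparameters are all assumed for the example.

```julia
# Toy gradient-descent counterfactual search (Eq. 2); all values assumed.
sigmoid(z) = 1 / (1 + exp(-z))

w, b = [1.0, 1.0], 0.0      # assumed classifier M(x) = σ(w'x + b)
x = [-2.0, -2.0]            # factual, currently classified as 0
t = 1.0                     # target class
λ, η = 0.05, 0.1            # penalty strength and step size

x′ = copy(x)
for _ in 1:1000
    p = sigmoid(w' * x′ + b)
    # ∇ℓ = (p - t)w for cross-entropy; ∇h = 2(x′ - x) for the L2 penalty
    g = (p - t) .* w .+ 2λ .* (x′ .- x)
    x′ .-= η .* g
end
```

The penalty keeps the counterfactual close to the factual, so the search settles just beyond the decision boundary rather than deep inside the target class.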
Yes and no!
While both are methodologically very similar, adversarial examples are meant to go undetected while CEs ought to be meaningful.
Effective counterfactuals should meet certain criteria ✅
NO!
Causal inference: counterfactuals are thought of as unobserved states of the world that we would like to observe in order to establish causality.
Counterfactual Explanations: counterfactuals are generated by perturbing the features of an input after some model has been trained.
But still … there is an intriguing link between the two domains.
When people say that counterfactuals should look realistic or plausible, they really mean that counterfactuals should be generated by the same Data Generating Process (DGP) as the factuals:
\[ x^\prime \sim p(x) \]
But how do we estimate \(p(x)\)? Two probabilistic approaches …
Schut et al. (2021) note that by maximizing predictive probabilities \(\sigma(M(x^\prime))\) for probabilistic models \(M\in\mathcal{\widetilde{M}}\) one implicitly minimizes epistemic and aleatoric uncertainty.
\[ x^\prime = \arg \min_{x^\prime} \ell(M(x^\prime),t) \quad , \quad M\in\mathcal{\widetilde{M}} \qquad(3)\]
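A minimal sketch of Equation 3 under stated assumptions: a toy "ensemble" of hand-coded logistic models stands in for \(\mathcal{\widetilde{M}}\), and we descend the loss averaged over the ensemble, with no complexity penalty.

```julia
# Toy sketch of Eq. 3: minimize loss over an assumed ensemble of
# logistic models (w, b); no penalty term.
sigmoid(z) = 1 / (1 + exp(-z))
ensemble = [([1.0, 1.0], 0.0), ([0.9, 1.1], 0.1), ([1.1, 0.9], -0.1)]
x′ = [-2.0, -2.0]
t, η = 1.0, 0.1
for _ in 1:500
    # average cross-entropy gradient across ensemble members
    g = sum((sigmoid(w' * x′ + b) - t) .* w for (w, b) in ensemble) ./ length(ensemble)
    x′ .-= η .* g
end
```

Without a penalty the search keeps moving until all ensemble members agree confidently on the target class, which is exactly the low-uncertainty behavior Schut et al. (2021) exploit.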
Instead of perturbing samples directly, some have proposed to instead traverse a lower-dimensional latent embedding learned through a generative model (Joshi et al. 2019).
\[ z^\prime = \arg \min_{z^\prime} \ell(M(\text{dec}(z^\prime)),t) + \lambda h(x^\prime) \qquad(4)\]
and
\[x^\prime = \text{dec}(z^\prime)\]
where \(\text{dec}(\cdot)\) is the decoder function.
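To illustrate Equations 4-5 without a trained VAE, here is a toy sketch in which a fixed linear map plays the role of the decoder; in practice \(\text{dec}(\cdot)\) would be the decoder of a generative model fitted to the data, and all parameters below are assumed.

```julia
# Toy latent-space search (Eqs. 4-5): a fixed linear "decoder" stands in
# for a trained generative model; all parameters are assumed.
sigmoid(z) = 1 / (1 + exp(-z))

D = reshape([1.0, 1.0], 2, 1)   # "decoder" mapping 1-D latent to 2-D input
dec(z) = D * z
w, b = [1.0, 1.0], 0.0          # assumed classifier
x = [-2.0, -2.0]                # factual
t, λ, η = 1.0, 0.05, 0.1

z′ = [-2.0]                      # initial latent code (dec(z′) ≈ x)
for _ in 1:1000
    x′ = dec(z′)
    p = sigmoid(w' * x′ + b)
    # chain rule: ∇_z = Dᵀ[(p - t)w + 2λ(x′ - x)]
    z′ .-= η .* (D' * ((p - t) .* w .+ 2λ .* (x′ .- x)))
end
x′ = dec(z′)
```

Because the search happens in latent space, the resulting counterfactual is constrained to the decoder's image, which is what makes latent-space search produce more plausible counterfactuals.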
CounterfactualExplanations.jl 📦
A unifying framework for generating Counterfactual Explanations.
Julia has an edge with respect to Trustworthy AI: it’s open-source, uniquely transparent and interoperable 🔴🟢🟣
Modular, composable, scalable!
Figure 6: Overview of package architecture. Modules are shown in red, structs in green and functions in blue.
using CounterfactualExplanations, Plots, GraphRecipes
plt = plot(AbstractGenerator, method=:tree, fontsize=10, nodeshape=:rect, size=(1000,700))
savefig(plt, joinpath(www_path,"generators.png"))
Figure 7: Type tree for AbstractGenerator.
plt = plot(AbstractFittedModel, method=:tree, fontsize=10, nodeshape=:rect, size=(1000,700))
savefig(plt, joinpath(www_path,"models.png"))
Figure 8: Type tree for AbstractFittedModel.
# Model
using CounterfactualExplanations.Models: LogisticModel
w = [1.0 1.0] # estimated coefficients
b = 0 # estimated bias
M = LogisticModel(w, [b])
# Select target class:
y = round(probs(M, x)[1])
target = ifelse(y==1.0,0.0,1.0) # opposite label as target
# Counterfactual search:
generator = GenericGenerator()
counterfactual = generate_counterfactual(
    x, target, counterfactual_data, M, generator
)
GenericGenerator.
using LinearAlgebra
Σ = Symmetric(reshape(randn(9),3,3).*0.01 + UniformScaling(1)) # MAP covariance matrix
μ = hcat(b, w)
M = CounterfactualExplanations.Models.BayesianLogisticModel(μ, Σ)
# Select target class:
y = round(probs(M, x)[1])
target = ifelse(y==1.0,0.0,1.0) # opposite label as target
# Counterfactual search:
params = GreedyGeneratorParams(
    δ = 0.5,
    n = 10
)
generator = GreedyGenerator(;params=params)
counterfactual = generate_counterfactual(
    x, target, counterfactual_data, M, generator
)
GreedyGenerator.
… instantiating model and attaching VAE.
The results in Figure 13 look great!
But things can also go wrong …
The VAE used to generate the counterfactual in Figure 14 is not expressive enough.
The counterfactual in Figure 15 is also valid … what to do?
Step 1: add composite type as subtype of AbstractFittedModel.
Step 2: dispatch logits and probs methods for new model type.
using Flux
using Statistics
import CounterfactualExplanations.Models: logits, probs
# Average predictions over the ensemble members:
logits(M::FittedEnsemble, X::AbstractArray) = mean(Flux.stack([nn(X) for nn in M.ensemble], 3), dims=3)
probs(M::FittedEnsemble, X::AbstractArray) = mean(Flux.stack([softmax(nn(X)) for nn in M.ensemble], 3), dims=3)
M = FittedEnsemble(ensemble)
Results for a simple deep ensemble also look convincing!
Adding support for torch models was easy! Here’s how I implemented it for torch classifiers trained in R.
Step 1: add composite type as subtype of AbstractFittedModel
Done here.
Step 2: dispatch logits and probs methods for new model type.
Done here.
Step 3: add gradient access.
Done here.
using RCall
synthetic = load_synthetic([:r_torch])
model = synthetic[:classification_binary][:models][:r_torch][:raw_model]
M = RTorchModel(model)
# Define generator:
generator = GenericGenerator()
# Generate recourse:
counterfactual = generate_counterfactual(
    x, target, counterfactual_data, M, generator
)
Idea 💡: let’s implement a generic generator with dropout!
Step 1: create a subtype of AbstractGradientBasedGenerator (adhering to some basic rules).
# Constructor:
struct DropoutGenerator <: AbstractGradientBasedGenerator
    loss::Symbol # loss function
    complexity::Function # complexity function
    mutability::Union{Nothing,Vector{Symbol}} # mutability constraints
    λ::AbstractFloat # strength of penalty
    ϵ::AbstractFloat # step size
    τ::AbstractFloat # tolerance for convergence
    p_dropout::AbstractFloat # dropout rate
end
Step 2: implement logic for generating perturbations.
import CounterfactualExplanations.Generators: generate_perturbations, ∇
using StatsBase
function generate_perturbations(generator::DropoutGenerator, counterfactual_state::CounterfactualState)
    𝐠ₜ = ∇(generator, counterfactual_state) # gradient of the counterfactual loss
    # Dropout: randomly set a share p_dropout of gradient entries to zero
    n_dropout = Int(round(generator.p_dropout * length(𝐠ₜ)))
    set_to_zero = sample(1:length(𝐠ₜ), n_dropout, replace=false)
    𝐠ₜ[set_to_zero] .= 0
    Δx′ = -(generator.ϵ .* 𝐠ₜ) # gradient step
    return Δx′
end
Develop package, register and submit to JuliaCon 2022.
Native support for deep learning models (Flux, torch).
Add latent space search.
MLJ, GLM, …
Flux optimizers.
JuliaCon 2022 - Explaining Black-Box Models through Counterfactuals